{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "performance.ipynb",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [
        "ISubpr_SSsiM"
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.4"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ISubpr_SSsiM"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "3jTMb1dySr3V",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "6DWfyNThSziV"
      },
      "source": [
        "# tf.function で性能アップ\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/customization/performance\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/ja/tutorials/customization/performance.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/ja/tutorials/customization/performance.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/ja/tutorials/customization/performance.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "J122XQYG7W6w"
      },
      "source": [
        "TensorFlow 2.0 では Eager Execution が既定で有効になっています。ユーザーインターフェイスは直感的で柔軟です（演算を一度だけ行う場合にはずっと簡単に、かつ迅速に実行されます）。しかしながら、それは性能と展開の面での犠牲の上に成り立っています。\n",
        "\n",
        "最高性能を得ながら、モデルをどこへでも展開できるようにするには、`tf.function` を使ってプログラムから計算グラフを作成します。\n",
        "AutoGraph のおかげで、驚くほど多くの Python コードが tf.function でそのまま動作しますが、気をつけなければならない落とし穴も存在します。\n",
        "\n",
        "ポイントと推奨事項は下記の通りです。\n",
        "\n",
        "- オブジェクトの変更やリストへの追加のような Python の副作用に依存しないこと\n",
        "- tf.functions は NumPy の演算や Python の組み込み演算よりも、TensorFlow の演算に適していること\n",
        "- 迷ったときは、`for x in y` というイディオムを使うこと"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "otIdN1TS8N7S",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "try:\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n",
        "import tensorflow as tf"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "D25apou9IOXa",
        "colab": {}
      },
      "source": [
        "import contextlib\n",
        "\n",
        "# 遭遇するかもしれないいくつかのエラーをデモするためのヘルパー関数\n",
        "@contextlib.contextmanager\n",
        "def assert_raises(error_class):\n",
        "  try:\n",
        "    yield\n",
        "  except error_class as e:\n",
        "    print('Caught expected exception \\n  {}: {}'.format(error_class, e))\n",
        "  except Exception as e:\n",
        "    print('Got unexpected exception \\n  {}: {}'.format(type(e), e))\n",
        "  else:\n",
        "    raise Exception('Expected {} to be raised but no error was raised!'.format(\n",
        "        error_class))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rfayNj-ZIkIB"
      },
      "source": [
        "あなたが定義した `tf.function` は TensorFlow Core の演算に似たものです。例えばそれを即時に実行することも、計算グラフで使うこともできますし、勾配を計算することも可能です。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "SbtT1-Wm70F2",
        "colab": {}
      },
      "source": [
        "# function は演算のように振る舞う\n",
        "\n",
        "@tf.function\n",
        "def add(a, b):\n",
        "  return a + b\n",
        "\n",
        "add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "uP-zUelB8DbX",
        "colab": {}
      },
      "source": [
        "# function は勾配を計算できる\n",
        "\n",
        "@tf.function\n",
        "def add(a, b):\n",
        "  return a + b\n",
        "\n",
        "v = tf.Variable(1.0)\n",
        "with tf.GradientTape() as tape:\n",
        "  result = add(v, 1.0)\n",
        "tape.gradient(result, v)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "l5qRjdbBVdU6",
        "colab": {}
      },
      "source": [
        "# function 内で function を使うこともできる\n",
        "\n",
        "@tf.function\n",
        "def dense_layer(x, w, b):\n",
        "  return add(tf.matmul(x, w), b)\n",
        "\n",
        "dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "uZ4Do2AV80cO"
      },
      "source": [
        "## トレーシングとポリモーフィズム\n",
        "\n",
        "Python の動的型付けは、関数をさまざまな型の引数で呼び出すことができ、Python がそれぞれのシナリオで異なる動作をするということを意味します。\n",
        "\n",
        "他方で、TensorFlow の計算グラフでは、dtype と shape の次元が静的であることが必要です。`tf.function` は、正しい計算グラフを生成するために必要なときには関数を再トレースして、このギャップをつなぐ役割を果たします。\n",
        "\n",
        "異なる型の引数を使って関数を呼び出し、何が起きるか見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "kojmJrgq8U9v",
        "colab": {}
      },
      "source": [
        "# Function はポリモーフィック\n",
        "\n",
        "@tf.function\n",
        "def double(a):\n",
        "  print(\"Tracing with\", a)\n",
        "  return a + a\n",
        "\n",
        "print(double(tf.constant(1)))\n",
        "print()\n",
        "print(double(tf.constant(1.1)))\n",
        "print()\n",
        "print(double(tf.constant(\"a\")))\n",
        "print()\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4pJqkDR_Q2wz"
      },
      "source": [
        "トレースの動作を制御するためには、下記のようなテクニックを使います。\n",
        "\n",
        "- 新しい `tf.function` を作成する。別々の `tf.function` オブジェクトがトレースを共有することはない。\n",
        "- 特定のトレースを得るには `get_concrete_function` メソッドを使用する。\n",
        "- 計算グラフの呼び出し時に1回だけトレースを行うには、 `input_signature` を指定して `tf.function` を呼び出す。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "mHg2CGtPQ3Hz",
        "colab": {}
      },
      "source": [
        "print(\"Obtaining concrete trace\")\n",
        "double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))\n",
        "print(\"Executing traced function\")\n",
        "print(double_strings(tf.constant(\"a\")))\n",
        "print(double_strings(a=tf.constant(\"b\")))\n",
        "print(\"Using a concrete trace with incompatible types will throw an error\")\n",
        "with assert_raises(tf.errors.InvalidArgumentError):\n",
        "  double_strings(tf.constant(1))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_BDMIRmu1RGB",
        "colab": {}
      },
      "source": [
        "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n",
        "def next_collatz(x):\n",
        "  print(\"Tracing with\", x)\n",
        "  return tf.where(tf.equal(x % 2, 0), x // 2, 3 * x + 1)\n",
        "\n",
        "print(next_collatz(tf.constant([1, 2])))\n",
        "# 1次元のテンソルを input signature として指定しているので、これは失敗する\n",
        "with assert_raises(ValueError):\n",
        "  next_collatz(tf.constant([[1, 2], [3, 4]]))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Es0WZkLIUSdu"
      },
      "source": [
        "## いつ再トレースするのか？\n",
        "\n",
        "ポリモーフィックな `tf.function` はトレーシングによって生成された具象関数のキャッシュを保持しています。キャッシュのキーは、実際にはその関数の引数及びキーワード引数から生成されたキーのタプルです。`tf.Tensor` 引数から生成されるキーは、テンソルの shape と型です。Python の組み込み型引数から生成されるキーはその値です。それ以外の Python の型では、キーはオブジェクトの `id()` に基づいており、メソッドはクラスのインスタンスひとつずつ独立にトレースされます。将来、TensorFlowには、Python オブジェクトについて安全にテンソルに変換できるような、より洗練されたキャッシングが追加されるかもしれません。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AY5oiQN0XIyA"
      },
      "source": [
        "## 引数は Python か？ Tensor か？\n",
        "\n",
        "しばしば、ハイパーパラメータやグラフ構成を制御するために Python の組み込み型の引数が使われます。例えば、`num_layers=10` や `training=True` あるいは `nonlinearity='relu'` のようにです。このため、この Python の組み込み型の引数が変更されると、計算グラフを再びトレースする必要があるということになります。\n",
        "\n",
        "しかし、グラフの生成を制御するために Python の組み込み型の引数を使用する必要はありません。これらのケースでは、Python引数の値の変更が不必要な再トレースを引き起こす可能性があります。例えば、この訓練ループでは、AutoGraph は動的に展開を行います。複数回トレースを行っていますが、生成される計算グラフは全く変わりません。これは少し非効率です。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "uydzR5JYUU8H",
        "colab": {}
      },
      "source": [
        "def train_one_step():\n",
        "  pass\n",
        "\n",
        "@tf.function\n",
        "def train(num_steps):\n",
        "  print(\"Tracing with num_steps = {}\".format(num_steps))\n",
        "  for _ in tf.range(num_steps):\n",
        "    train_one_step()\n",
        "\n",
        "train(num_steps=10)\n",
        "train(num_steps=20)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "f6pjnylLUW8P"
      },
      "source": [
        "ここでの簡単な回避方法は、生成されたグラフの shape が変わらないのであれば、引数をテンソルにキャストすることです。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "TmL8T-w3UYes",
        "colab": {}
      },
      "source": [
        "train(num_steps=tf.constant(10))\n",
        "train(num_steps=tf.constant(20))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "129-iRsPS-gY"
      },
      "source": [
        "## `tf.function` の中の副作用\n",
        "\n",
        "一般的には、（印字やオブジェクト変更のような）Python の副作用は、トレーシングの最中にだけ発生します。それでは、どうしたら `tf.function` で安定的に副作用を起こすことができるでしょうか？\n",
        "\n",
        "一般的な原則は、トレースをデバッグする際にだけ Python の副作用を使用するというものです。あるいは、`tf.Variable.assign`、`tf.print`、そして `tf.summary` のような TensorFlow の演算を使うことで、コードがトレースされるときにも、TensorFlowランタイムによって都度呼び出される際にも、確実に実行されるようにできます。一般には、関数型のスタイルを使用することで最も良い結果を得られます。 "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "w2sACuZ9TTRk",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def f(x):\n",
        "  print(\"Traced with\", x)\n",
        "  tf.print(\"Executed with\", x)\n",
        "\n",
        "f(1)\n",
        "f(1)\n",
        "f(2)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "e1I0dPiqTV8H"
      },
      "source": [
        "`tf.function` が呼び出されるたびに Python のコードを実行したいのであれば、`tf.py_function` がぴったりです。`tf.py_function` の欠点は、ポータブルでないこと、それほど性能が高くないこと、（マルチGPU、TPUの）分散環境ではうまく動作しないことなどです。また、`tf.py_function` は計算グラフに組み込まれるため、入出力すべてをテンソルにキャストします。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "7aJD--9qTWmg",
        "colab": {}
      },
      "source": [
        "external_list = []\n",
        "\n",
        "def side_effect(x):\n",
        "  print('Python side effect')\n",
        "  external_list.append(x)\n",
        "\n",
        "@tf.function\n",
        "def f(x):\n",
        "  tf.py_function(side_effect, inp=[x], Tout=[])\n",
        "\n",
        "f(1)\n",
        "f(1)\n",
        "f(1)\n",
        "assert len(external_list) == 3\n",
        "# .numpy() call required because py_function casts 1 to tf.constant(1)\n",
        "assert external_list[0].numpy() == 1\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "msTmv-oyUNaf"
      },
      "source": [
        "## Python の状態に注意\n",
        "\n",
        "ジェネレーターやイテレーターなど Python の機能の多くは、状態を追跡するために Python のランタイムに依存しています。これらの仕組みは、一般的には Eager モードでも期待通りに動作しますが、トレーシングの振る舞いにより、`tf.function` の中では予期しないことが起きることがあります。\n",
        "\n",
        "1例として、イテレーターの状態が進むのは Python の副作用であり、トレーシングの中だけで発生します。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "FNPD4unZUedH",
        "colab": {}
      },
      "source": [
        "external_var = tf.Variable(0)\n",
        "@tf.function\n",
        "def buggy_consume_next(iterator):\n",
        "  external_var.assign_add(next(iterator))\n",
        "  tf.print(\"Value of external_var:\", external_var)\n",
        "\n",
        "iterator = iter([0, 1, 2, 3])\n",
        "buggy_consume_next(iterator)\n",
        "# 次のコードは、イテレーターの次の値を使うのではなく、最初の値を再利用する\n",
        "buggy_consume_next(iterator)\n",
        "buggy_consume_next(iterator)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5XMGXMu-Ufjm"
      },
      "source": [
        "イテレーターが tf.function の中で生成されすべて使われる場合には、正しく動作するはずです。しかし、イテレーター全体がトレースされることとなり、巨大な計算グラフの生成をまねく可能性があります。これは、望みどおりの動作かもしれません。しかし、もし Python のリストとして表されたメモリー上の巨大なデータセットを使って訓練を行うとすると、これは非常に大きな計算グラフを生成することになり、`tf.function` がスピードアップにはつながらないと考えられます。\n",
        "\n",
        "Python データを繰り返し使用する場合、もっとも安全な方法は tf.data.Dataset でラップして、`for x in y` というイディオムを使用することです。AutoGraph には、`y` がテンソルあるいは tf.data.Dataset である場合、`for` ループを安全に変換する特別な機能があります。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ms7f1o_QUiHE",
        "colab": {}
      },
      "source": [
        "def measure_graph_size(f, *args):\n",
        "  g = f.get_concrete_function(*args).graph\n",
        "  print(\"{}({}) contains {} nodes in its graph\".format(\n",
        "      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))\n",
        "\n",
        "@tf.function\n",
        "def train(dataset):\n",
        "  loss = tf.constant(0)\n",
        "  for x, y in dataset:\n",
        "    loss += tf.abs(y - x) # ダミー計算\n",
        "  return loss\n",
        "\n",
        "small_data = [(1, 1)] * 2\n",
        "big_data = [(1, 1)] * 10\n",
        "measure_graph_size(train, small_data)\n",
        "measure_graph_size(train, big_data)\n",
        "\n",
        "measure_graph_size(train, tf.data.Dataset.from_generator(\n",
        "    lambda: small_data, (tf.int32, tf.int32)))\n",
        "measure_graph_size(train, tf.data.Dataset.from_generator(\n",
        "    lambda: big_data, (tf.int32, tf.int32)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dGDstsFpWHEI"
      },
      "source": [
        "Python/Numpy のデータを Dataset でラップする際には、`tf.data.Dataset.from_generator` と `tf.data.Dataset.from_tensors` の違いに留意しましょう。前者はデータを Python のまま保持し `tf.py_function` を通じて取得するため、性能に影響する場合があります。これに対して後者はデータのコピーを計算グラフの中の、ひとつの大きな `tf.constant()` に結びつけるため、メモリー消費に影響する可能性があります。 \n",
        "\n",
        "TFRecordDataset/CsvDataset/などを通じてデータをファイルから読み込むことが、データを使用する最も効率的な方法です。TensorFlow 自身が Python とは関係なく非同期のデータ読み込みとプリフェッチを管理することができるからです。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tRdlnCfV_UTn"
      },
      "source": [
        "## 自動的な依存関係の制御\n",
        "\n",
        "プログラミングモデルとしての関数が一般的なデータフローグラフに対して非常に優位である点は、意図したコードの振る舞いがどのようなものであるかということについて、より多くの情報をランタイムに与えられるということにあります。\n",
        "\n",
        "例えば、同じ変数を何度も読んだり書いたりするコードを書く場合、データフローグラフではもともと意図されていた演算の順番を自然に組み込むわけではありません。`tf.function` の中では、もともとの Python コードの文の実行順序を参照することで、実行順序の曖昧さを解消します。これにより、`tf.function` の中のステートフルな演算の順序が、先行実行モードのセマンティクスを模していることになります。\n",
        "\n",
        "これは、手動で制御の依存関係を加える必要がないことを意味しています。`tf.function` は十分賢いので、あなたのコードが正しく動作するために必要十分な最小限の制御の依存関係を追加してくれます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "SASm0ss8erVX",
        "colab": {}
      },
      "source": [
        "# 自動的な依存関係の制御\n",
        "\n",
        "a = tf.Variable(1.0)\n",
        "b = tf.Variable(2.0)\n",
        "\n",
        "@tf.function\n",
        "def f(x, y):\n",
        "  a.assign(y * b)\n",
        "  b.assign_add(x * a)\n",
        "  return a + b\n",
        "\n",
        "f(1.0, 2.0)  # 10.0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lPr_6mK_AQWL"
      },
      "source": [
        "## 変数\n",
        "\n",
        "`tf.function` の中では、意図したコードの実行順序を活用するという同じアイデアを使って、変数の作成と活用を簡単に行うことができます。しかし、ひとつだけ非常に重要な欠点があります。それは、変数を使った場合、先行実行モードとグラフモードでは動作が変わるコードを書いてしまう可能性があるということです。\n",
        "\n",
        "特に、呼び出しの都度新しい変数を作成する場合にこれが発生します。トレーシングの意味では、`tf.function` は呼び出しのたびに同じ変数を再利用しますが、Eager モードでは呼び出しごとに新しい変数を生成します。この間違いを防止するため、`tf.function` は危険な変数の生成動作を見つけるとエラーを発生させます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Tx0Vvnb_9OB-",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def f(x):\n",
        "  v = tf.Variable(1.0)\n",
        "  v.assign_add(x)\n",
        "  return v\n",
        "\n",
        "with assert_raises(ValueError):\n",
        "  f(1.0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "DKzNjVg8h4ao",
        "colab": {}
      },
      "source": [
        "# しかし、曖昧さの無いコードは大丈夫\n",
        "\n",
        "v = tf.Variable(1.0)\n",
        "\n",
        "@tf.function\n",
        "def f(x):\n",
        "  return v.assign_add(x)\n",
        "\n",
        "print(f(1.0))  # 2.0\n",
        "print(f(2.0))  # 4.0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "HQrG5_kOiKl_",
        "colab": {}
      },
      "source": [
        "# 初めて関数が実行されるときだけ変数が生成されることを保証できれば\n",
        "# tf.function 内で変数を作成できる\n",
        "\n",
        "class C: pass\n",
        "obj = C(); obj.v = None\n",
        "\n",
        "@tf.function\n",
        "def g(x):\n",
        "  if obj.v is None:\n",
        "    obj.v = tf.Variable(1.0)\n",
        "  return obj.v.assign_add(x)\n",
        "\n",
        "print(g(1.0))  # 2.0\n",
        "print(g(2.0))  # 4.0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "_IOVc1eujMH2",
        "colab": {}
      },
      "source": [
        "# 変数の初期化は、関数の引数や他の変数の値に依存可能\n",
        "# 制御の依存関係を生成するのと同じ手法で、正しい初期化の順序を発見可能\n",
        "\n",
        "state = []\n",
        "@tf.function\n",
        "def fn(x):\n",
        "  if not state:\n",
        "    state.append(tf.Variable(2.0 * x))\n",
        "    state.append(tf.Variable(state[0] * 3.0))\n",
        "  return state[0] * x * state[1]\n",
        "\n",
        "print(fn(tf.constant(1.0)))\n",
        "print(fn(tf.constant(3.0)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5f05Vr_YBUCz"
      },
      "source": [
        "# AutoGraph の使用\n",
        "\n",
        "[autograph](https://www.tensorflow.org/guide/function) ライブラリは `tf.function` に完全に統合されており、計算グラフの中で動的に実行される条件文や繰り返しを書くことができます。\n",
        "\n",
        "`tf.cond` や `tf.while_loop` は `tf.function` でも使えますが、制御フローを含むコードは、命令形式で書いたほうが書きやすいし理解しやすいです。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "yCQTtTPTW3WF",
        "colab": {}
      },
      "source": [
        "# 単純な繰り返し\n",
        "\n",
        "@tf.function\n",
        "def f(x):\n",
        "  while tf.reduce_sum(x) > 1:\n",
        "    tf.print(x)\n",
        "    x = tf.tanh(x)\n",
        "  return x\n",
        "\n",
        "f(tf.random.uniform([5]))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "jlQD1ffRXJhl",
        "colab": {}
      },
      "source": [
        "# 興味があれば AutoGraph が生成するコードを調べることができる\n",
        "# ただし、アセンブリ言語を読むような感じがする\n",
        "\n",
        "def f(x):\n",
        "  while tf.reduce_sum(x) > 1:\n",
        "    tf.print(x)\n",
        "    x = tf.tanh(x)\n",
        "  return x\n",
        "\n",
        "print(tf.autograph.to_code(f))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xgKmkrNTZSyz"
      },
      "source": [
        "## AutoGraph: 条件分岐\n",
        "\n",
        "AutoGraph は `if` 文を等価である `tf.cond` の呼び出しに変換します。\n",
        "\n",
        "この置換は条件がテンソルである場合に行われます。そうでない場合には、条件分岐はトレーシングの中で実行されます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "E-7KllizZYsy",
        "colab": {}
      },
      "source": [
        "def test_tf_cond(f, *args):\n",
        "  g = f.get_concrete_function(*args).graph\n",
        "  if any(node.name == 'cond' for node in g.as_graph_def().node):\n",
        "    print(\"{}({}) uses tf.cond.\".format(\n",
        "        f.__name__, ', '.join(map(str, args))))\n",
        "  else:\n",
        "    print(\"{}({}) executes normally.\".format(\n",
        "        f.__name__, ', '.join(map(str, args))))\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "o86paGR-Zadi",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def hyperparam_cond(x, training=True):\n",
        "  if training:\n",
        "    x = tf.nn.dropout(x, rate=0.5)\n",
        "  return x\n",
        "\n",
        "@tf.function\n",
        "def maybe_tensor_cond(x):\n",
        "  if x < 0:\n",
        "    x = -x\n",
        "  return x\n",
        "\n",
        "test_tf_cond(hyperparam_cond, tf.ones([1], dtype=tf.float32))\n",
        "test_tf_cond(maybe_tensor_cond, tf.constant(-1))\n",
        "test_tf_cond(maybe_tensor_cond, -1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5xFLfdApZh8q"
      },
      "source": [
        "`tf.cond` には、色々と注意すべき細かな点があります。\n",
        "\n",
        "- `tf.cond` は条件分岐の両方をトレーシングし、条件に従って実行時に適切な分岐を選択することで機能します。分岐の両方をトレースすることで、Python プログラムを予期せず実行する可能性があります。\n",
        "- `tf.cond` では、分岐の一方が後ほど使用されるテンソルを作成する場合、もう一方の分岐もそのテンソルを作成することが必要です。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "VTMoZEVaZiwk",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def f():\n",
        "  x = tf.constant(0)\n",
        "  if tf.constant(True):\n",
        "    x = x + 1\n",
        "    print(\"Tracing `then` branch\")\n",
        "  else:\n",
        "    x = x - 1\n",
        "    print(\"Tracing `else` branch\")\n",
        "  return x\n",
        "\n",
        "f()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "k_dxWHeFZlaQ",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def f():\n",
        "  if tf.constant(True):\n",
        "    x = tf.ones([3, 3])\n",
        "  return x\n",
        "\n",
        "# 分岐のどちらの枝でも `x` を定義する必要があるためエラーが発生\n",
        "with assert_raises(ValueError):\n",
        "  f()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yho4J0a0ZkQS"
      },
      "source": [
        "## AutoGraph と繰り返し\n",
        "\n",
        "AutoGraph には繰り返しの変換にいくつかの単純なルールがあります。\n",
        "\n",
        "- `for`: 反復可能オブジェクトがテンソルである場合に変換する\n",
        "- `while`: while 条件がテンソルに依存している場合に変換する\n",
        "\n",
        "繰り返しが変換される場合、`tf.while_loop` によって動的に展開されます。あるいは、 `for x in tf.data.Dataset` という特別なケースの場合には、 `tf.data.Dataset.reduce` に変換されます。\n",
        "\n",
        "繰り返しが変換されない場合、それは静的に展開されます。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "OyzGNQAuZsky",
        "colab": {}
      },
      "source": [
        "def test_dynamically_unrolled(f, *args):\n",
        "  g = f.get_concrete_function(*args).graph\n",
        "  if any(node.name == 'while' for node in g.as_graph_def().node):\n",
        "    print(\"{}({}) uses tf.while_loop.\".format(\n",
        "        f.__name__, ', '.join(map(str, args))))\n",
        "  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):\n",
        "    print(\"{}({}) uses tf.data.Dataset.reduce.\".format(\n",
        "        f.__name__, ', '.join(map(str, args))))\n",
        "  else:\n",
        "    print(\"{}({}) gets unrolled.\".format(\n",
        "        f.__name__, ', '.join(map(str, args))))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Q7tmncQTZt6_",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def for_in_range():\n",
        "  x = 0\n",
        "  for i in range(5):\n",
        "    x += i\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(for_in_range)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "H-SM-c-qTuoX",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def for_in_tfrange():\n",
        "  x = tf.constant(0, dtype=tf.int32)\n",
        "  for i in tf.range(5):\n",
        "    x += i\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(for_in_tfrange)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "abXQ2iwBTuoZ",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def for_in_tfdataset():\n",
        "  x = tf.constant(0, dtype=tf.int64)\n",
        "  for i in tf.data.Dataset.range(5):\n",
        "    x += i\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(for_in_tfdataset)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "l6s7aU-padY5",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def while_py_cond():\n",
        "  x = 5\n",
        "  while x > 0:\n",
        "    x -= 1\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(while_py_cond)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "QCLFaZnxTuoc",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def while_tf_cond():\n",
        "  x = tf.constant(5)\n",
        "  while x > 0:\n",
        "    x -= 1\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(while_tf_cond)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dSr64Xn6ap-S"
      },
      "source": [
        " 繰り返しに、テンソルに依存する `break` や、途中での `return` がある場合、一番外側の条件あるいは反復可能オブジェクトはテンソルである必要があります。 \n",
        " \n",
        " 比較してみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s2eaWCe2Tuof",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def while_py_true_py_break(x):\n",
        "  while True:  # py true\n",
        "    if x == 0: # py break\n",
        "      break\n",
        "    x -= 1\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(while_py_true_py_break, 5)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Kx2Z0K_uTuog",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def buggy_while_py_true_tf_break(x):\n",
        "  while True:   # py true\n",
        "    if tf.equal(x, 0): # tf break\n",
        "      break\n",
        "    x -= 1\n",
        "  return x\n",
        "\n",
        "with assert_raises(TypeError):\n",
        "  test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-ozYU9zQTuoj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def while_tf_true_tf_break(x):\n",
        "  while tf.constant(True): # tf true\n",
        "    if x == 0:  # py break\n",
        "      break\n",
        "    x -= 1\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(while_tf_true_tf_break, 5)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bZSpBq5CTuol",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def buggy_py_for_tf_break():\n",
        "  x = 0\n",
        "  for i in range(5):  # py for\n",
        "    if tf.equal(i, 3): # tf break\n",
        "      break\n",
        "    x += i\n",
        "  return x\n",
        "\n",
        "with assert_raises(TypeError):\n",
        "  test_dynamically_unrolled(buggy_py_for_tf_break)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DT9zyKTwTuon",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def tf_for_py_break():\n",
        "  x = 0\n",
        "  for i in tf.range(5): # tf for\n",
        "    if i == 3:  # py break\n",
        "      break\n",
        "    x += i\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(tf_for_py_break)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hyksHW9TCukR"
      },
      "source": [
        "動的に展開される繰り返しの結果を集計するため、`tf.TensorArray` を使いたくなるかもしれません。"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "HJ3Vb3dXfefN",
        "colab": {}
      },
      "source": [
        "batch_size = 2\n",
        "seq_len = 3\n",
        "feature_size = 4\n",
        "\n",
        "def rnn_step(inp, state):\n",
        "  return inp + state\n",
        "\n",
        "@tf.function\n",
        "def dynamic_rnn(rnn_step, input_data, initial_state):\n",
        "  # [batch, time, features] -> [time, batch, features]\n",
        "  input_data = tf.transpose(input_data, [1, 0, 2])\n",
        "  max_seq_len = input_data.shape[0]\n",
        "\n",
        "  states = tf.TensorArray(tf.float32, size=max_seq_len)\n",
        "  state = initial_state\n",
        "  for i in tf.range(max_seq_len):\n",
        "    state = rnn_step(input_data[i], state)\n",
        "    states = states.write(i, state)\n",
        "  return tf.transpose(states.stack(), [1, 0, 2])\n",
        "  \n",
        "dynamic_rnn(rnn_step,\n",
        "            tf.random.uniform([batch_size, seq_len, feature_size]),\n",
        "            tf.zeros([batch_size, feature_size]))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9gmLpHY-bkly"
      },
      "source": [
        "`tf.cond` と同様に、`tf.while_loop` にも、色々と注意すべき細かな点があります。\n",
        "\n",
        "- 繰り返しの実行回数が 0 である可能性があるため、while_loop の後で使用されるテンソルは、繰り返しの前に初期化されなければならない\n",
        "- すべての繰り返しの変数は、各繰り返しを通じてその形状と dtype が変わらないことが必要"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "CocT5RHwblrQ",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def buggy_loop_var_uninitialized():\n",
        "  for i in tf.range(3):\n",
        "    x = i\n",
        "  return x\n",
        "\n",
        "with assert_raises(ValueError):\n",
        "  buggy_loop_var_uninitialized()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "BnfgidIhTuow",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def f():\n",
        "  x = tf.constant(0)\n",
        "  for i in tf.range(3):\n",
        "    x = i\n",
        "  return x\n",
        "\n",
        "f()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "L_0qnF58Tuoy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def buggy_loop_type_changes():\n",
        "  x = tf.constant(0, dtype=tf.float32)\n",
        "  for i in tf.range(3): # tf.int32 型のテンソルを1つづつ取り出して…\n",
        "    x = i\n",
        "  return x\n",
        "\n",
        "with assert_raises(tf.errors.InvalidArgumentError):\n",
        "  buggy_loop_type_changes()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "kWF189prbuK0",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def buggy_concat():\n",
        "  x = tf.ones([0, 10])\n",
        "  for i in tf.range(5):\n",
        "    x = tf.concat([x, tf.ones([1, 10])], axis=0)\n",
        "  return x\n",
        "\n",
        "with assert_raises(ValueError):\n",
        "  buggy_concat()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "VO6QXZGlTuo4",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def concat_with_padding():\n",
        "  x = tf.zeros([5, 10])\n",
        "  for i in tf.range(5):\n",
        "    x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)\n",
        "    x.set_shape([5, 10])\n",
        "  return x\n",
        "\n",
        "concat_with_padding()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}